import os
import random
from argparse import ArgumentParser

import torch
import torch.nn.utils as utils
from tqdm import tqdm

import sampling as stg
from PointerNet import GATEncoder, MHADecoder
from inout import load_dataset
from utils import *
# Work from the folder containing this script, so the relative paths used
# below (dataset, benchmarks, checkpoints) resolve against it
os.chdir(
    os.path.dirname(
        os.path.abspath(__file__)
    )
)
# Device used for training and validation: GPU when available, CPU otherwise
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Number of training instances processed between two mid-epoch validation probes
PROBE_EVERY = 2500


@torch.no_grad()
def validation(encoder: torch.nn.Module,
               decoder: torch.nn.Module,
               val_set: list,
               num_sols: int = 16,
               lc_dim: int = 4,
               ld_dim: int = 12):
    """
    Evaluate the model on a validation set and print the average gap.

    For every instance, `num_sols` latent vectors are drawn, one solution is
    sampled per latent vector, and the best (smallest) makespan is compared
    with the value stored in `ins['makespan']` (presumably the reference
    makespan of the instance - confirm against the dataset loader).

    Both modules are switched to eval mode and are left in eval mode on
    return.

    Args:
        encoder: Encoder.
        decoder: Decoder.
        val_set: Validation set (iterable of instances).
        num_sols: Number of solutions to generate for each instance.
        lc_dim: Size of the continuous part of the latent vector, sampled
            uniformly in [-1, 1).
        ld_dim: Size of the categorical part of the latent vector, sampled
            as a random one-hot vector.
    Returns:
        The average gap (in percent) accumulated by the ObjMeter.
    """
    encoder.eval()
    decoder.eval()
    meter = ObjMeter()

    for instance in val_set:
        # Continuous latent part: one row per sampled solution
        cont_part = torch.empty(num_sols, lc_dim).uniform_(-1, 1)

        # Categorical latent part: a single 1 per row at a random column
        hot_cols = torch.randint(0, ld_dim, (num_sols,))
        disc_part = torch.zeros((num_sols, ld_dim), dtype=torch.float32)
        disc_part.scatter_(1, hot_cols.unsqueeze(1), 1.0)

        # Latent layout is [categorical | continuous]
        latent = torch.cat([disc_part, cont_part], dim=-1).to(device)

        # Only the makespans are needed here, the solutions are discarded
        _, makespans = stg.sampling(instance, encoder, decoder, latent,
                                    bs=num_sols, device=device)

        # Relative distance (%) of the best sampled makespan
        best_ms = makespans.min().item()
        meter.update(instance, (best_ms / instance['makespan'] - 1) * 100)

    # Print stats
    avg_gap = meter.avg
    print(f"\t\tVal set: AVG Gap={avg_gap:.3f}")
    print(meter)
    return avg_gap


def _save_if_improved(encoder: torch.nn.Module,
                      decoder: torch.nn.Module,
                      val_gap: float,
                      best_gap,
                      model_path: str):
    """
    Checkpoint the model when the validation gap improves.

    Args:
        encoder: Encoder whose state dict is saved.
        decoder: Decoder, saved as a whole module.
        val_gap: Validation gap just measured.
        best_gap: Best validation gap so far, or None if none was measured.
        model_path: Destination file of the checkpoint.
    Returns:
        The best validation gap after taking `val_gap` into account.
    """
    if best_gap is None or val_gap < best_gap:
        # Checkpoint layout: (encoder state dict, decoder module)
        torch.save((encoder.state_dict(), decoder), model_path)
        return val_gap
    return best_gap


def train(encoder: torch.nn.Module,
          decoder: torch.nn.Module,
          train_set: list,
          val_set: list,
          epochs: int = 50,
          virtual_bs: int = 16,
          num_sols: int = 256,
          lc_dim: int = 4,
          ld_dim: int = 12,
          model_path: str = 'checkpoints/PointerNet.pt',
          lr: float = None):
    """
    Train the Pointer Network.

    For each training instance, `num_sols` solutions are sampled, each
    conditioned on a random latent vector. The loss is the negative
    log-likelihood of the best sampled solution, weighted by how far its
    makespan is from the mean of the sample (in standard deviations).
    Gradients are accumulated over `virtual_bs` instances before each
    optimizer step. The model is validated every PROBE_EVERY instances and
    at the end of each epoch; a checkpoint is written whenever the
    validation gap improves.

    Args:
        encoder: Encoder to train.
        decoder: Decoder to train.
        train_set: Training set.
        val_set: Validation set.
        epochs: Number of epochs.
        virtual_bs: Virtual batch size that gives the number of instances
            predicted before back-propagation.
        num_sols: Number of solutions sampled per training instance.
        lc_dim: Size of the continuous part of the latent vector.
        ld_dim: Size of the categorical (one-hot) part of the latent vector.
        model_path: File where the best model is saved.
        lr: Learning rate of Adam. When None, the `-lr` command line
            argument is used.
    """
    frac, _best = 1. / virtual_bs, None
    size = len(train_set)
    indices = list(range(size))
    if lr is None:
        # Backward-compatible default: learning rate from the CLI arguments
        lr = args.lr
    # Lower bound on the makespan std, avoids 0/0 when all samples are equal
    std_eps = 1e-8

    # torch.save does not create missing folders: make sure the target exists
    ckpt_dir = os.path.dirname(model_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    ### OPTIMIZER
    # Optimize the modules received as arguments (not module-level globals)
    opti = torch.optim.Adam(list(encoder.parameters()) +
                            list(decoder.parameters()),
                            lr=lr)

    # History over the whole run, printed once training is over
    loss_list = []
    val_score_list = []

    print("Training ...")
    for epoch in range(0, epochs):
        losses = AverageMeter()
        temp_losses = AverageMeter()
        gaps = ObjMeter()
        random.shuffle(indices)
        cnt = 0
        # validation() leaves the modules in eval mode: switch back explicitly
        encoder.train()
        decoder.train()

        # tqdm wraps the list (not the enumerate) so it knows the total
        for idx, i in enumerate(tqdm(indices)):
            ins = train_set[i]
            cnt += 1

            # Latent vectors: [one-hot categorical | continuous in [-1, 1)]
            lc_var = torch.empty(num_sols, lc_dim).uniform_(-1, 1)
            ld_var = torch.zeros((num_sols, ld_dim), dtype=torch.float32)
            one_hot_idx = torch.randint(0, ld_dim, (num_sols,))
            ld_var[torch.arange(num_sols), one_hot_idx] = 1
            l_var = torch.cat([ld_var, lc_var], dim=-1).to(device)  # [batch, latent_dim]

            # Training step (sample solutions)
            _, probs, mss = stg.sample_training(ins, encoder, decoder, l_var,
                                                bs=num_sols, device=device)
            ms, argmin = mss.min(-1)

            # Self-labeling: the best sampled solution is the target, and its
            # loss is weighted by its distance from the mean makespan
            mean_ms = mss.mean()
            ms_std = mss.std(unbiased=False).clamp_min(std_eps)
            ms_weight = torch.abs((ms - mean_ms) / ms_std)

            batch_loss = -torch.log(probs[argmin]).mean()
            loss = ms_weight * batch_loss

            losses.update(loss.item())
            temp_losses.update(loss.item())
            gaps.update(ins, (ms.item() / ins['makespan'] - 1) * 100)

            # Virtual batching for managing without masking different sizes
            (loss * frac).backward()
            if cnt == virtual_bs or idx + 1 == size:
                opti.step()
                opti.zero_grad()
                loss_list.append(temp_losses.avg)
                temp_losses = AverageMeter()  # restart the per-batch average
                cnt = 0

            # Probe model
            if idx > 0 and idx % PROBE_EVERY == 0:
                val_gap = validation(encoder, decoder, val_set, num_sols=256,
                                     lc_dim=lc_dim, ld_dim=ld_dim)
                val_score_list.append(val_gap)
                _best = _save_if_improved(encoder, decoder, val_gap, _best,
                                          model_path)
                # Resume training mode only after the checkpoint is written
                encoder.train()
                decoder.train()

        # ...log the running loss
        avg_loss, avg_gap = losses.avg, gaps.avg
        logger.train(epoch, avg_loss, avg_gap)
        print(f'\tEPOCH {epoch:02}: avg loss={losses.avg:.4f}')
        print(f"\t\tTrain: AVG Gap={avg_gap:2.3f}")
        print(gaps)

        # Test model and save
        val_gap = validation(encoder, decoder, val_set, num_sols=128,
                             lc_dim=lc_dim, ld_dim=ld_dim)
        logger.validation(val_gap)
        _best = _save_if_improved(encoder, decoder, val_gap, _best, model_path)
        logger.flush()
    print("loss_list: ", loss_list)
    print("val_score_list: ", val_score_list)
#
# Command line interface. NOTE: parsing happens at import time because
# `args` and `logger` are read as module-level globals by train().
parser = ArgumentParser(description='PointerNet arguments for the JSP')
parser.add_argument("-data_path", type=str, default="./dataset5k",
                    required=False, help="Path to the training data.")
parser.add_argument("-model_path", type=str, required=False,
                    default=None, help="Path to the model.")
parser.add_argument("-enc_hidden", type=int, default=64, required=False,
                    help="Hidden size of the encoder.")
parser.add_argument("-enc_out", type=int, default=128, required=False,
                    help="Output size of the encoder.")
parser.add_argument("-mem_hidden", type=int, default=64, required=False,
                    help="Hidden size of the memory network.")
parser.add_argument("-mem_out", type=int, default=128, required=False,
                    help="Output size of the memory network.")
parser.add_argument("-clf_hidden", type=int, default=128, required=False,
                    help="Hidden size of the classifier.")
parser.add_argument("-lr", type=float, default=0.0001, required=False,
                    help="Learning rate in the first checkpoint.")
parser.add_argument("-epochs", type=int, default=20, required=False,
                    help="Number of epochs.")
parser.add_argument("-bs", type=int, default=16, required=False,
                    help="Virtual batch size.")
parser.add_argument("-beta", type=int, default=64, required=False,
                    help="Number of sampled solutions.")
# The latent sizes are used as tensor dimensions, so they must be parsed as
# integers: a float (e.g. 4.0) is rejected when used as a tensor size.
parser.add_argument("-latent_cont_dim", type=int, default=4, required=False,
                    help="dimension of continuous latent variable")
parser.add_argument("-latent_disc_dim", type=int, default=12, required=False,
                    help="dimension of categorical latent variable")

args = parser.parse_args()
print(args)
# The run name identifies the logger and the default checkpoint file
run_name = f"Init_PtrNet-BS{args.bs}-B{args.beta}"
logger = Logger(run_name)

if __name__ == '__main__':
    print(f"Using device: {device}")

    # Datasets: only the validation benchmarks are loaded with an explicit device
    train_set = load_dataset(args.data_path)
    val_set = load_dataset('./benchmarks/validation', device=device)

    # Encoder: its input size is the second dimension of the 'x' entry of
    # the first training instance
    in_size = train_set[0]['x'].shape[1]
    _enc = GATEncoder(in_size,
                      hidden_size=args.enc_hidden,
                      embed_size=args.enc_out)
    _enc = _enc.to(device)

    # Decoder: sized on the encoder output and on the latent dimensions
    dec_kwargs = dict(encoder_size=_enc.out_size,
                      context_size=stg.JobShopStates.size,
                      hidden_size=args.mem_hidden,
                      lc_dim=args.latent_cont_dim,
                      ld_dim=args.latent_disc_dim,
                      mem_size=args.mem_out,
                      clf_size=args.clf_hidden)
    _dec = MHADecoder(**dec_kwargs).to(device)

    if args.model_path is None:
        # Fresh run: checkpoints are named after the run
        m_path = f"checkpoints/{run_name}.pt"
    else:
        # Resume: keep writing to the checkpoint that is being loaded
        print(f"Loading {args.model_path}.")
        m_path = str(args.model_path)
        # The checkpoint holds (encoder state dict, decoder module): the
        # loaded decoder replaces the one built above.
        # NOTE(review): this unpickles a full module, so only load trusted
        # files. Recent PyTorch versions default torch.load to
        # weights_only=True, which may reject this checkpoint - confirm
        # against the installed version.
        _enc_state, _dec = torch.load(args.model_path, map_location=device)
        _enc.load_state_dict(_enc_state)
    print(_enc)
    print(_dec)

    train(_enc, _dec, train_set, val_set,
          epochs=args.epochs,
          virtual_bs=args.bs,
          num_sols=args.beta,
          lc_dim=args.latent_cont_dim,
          ld_dim=args.latent_disc_dim,
          model_path=m_path)